import uuid
from typing import TYPE_CHECKING, Dict, List, Optional

from sqlalchemy import JSON, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from mirix.orm.sqlalchemy_base import SqlalchemyBase
from mirix.schemas.step import Step as PydanticStep

if TYPE_CHECKING:
    from mirix.orm.message import Message
    from mirix.orm.organization import Organization
    from mirix.orm.provider import Provider


class Step(SqlalchemyBase):
    """Tracks all metadata for an agent step.

    A row records the organization and provider/model a step ran against,
    its token-usage counters, and optional tags. ``Message`` rows associated
    with the step point back to it through ``Message.step``.
    """

    __tablename__ = "steps"
    __pydantic_model__ = PydanticStep

    # Primary key; generated as "step-<uuid4>" when the caller does not supply one.
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"step-{uuid.uuid4()}")
    origin: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The surface that this agent step was initiated from.")
    # Both foreign keys are nullable, so their annotations are Optional to match
    # the schema. ondelete="RESTRICT" makes the database refuse to delete an
    # organization/provider row that is still referenced by a step.
    organization_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("organizations.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the organization that this step ran for",
    )
    provider_id: Mapped[Optional[str]] = mapped_column(
        ForeignKey("providers.id", ondelete="RESTRICT"),
        nullable=True,
        doc="The unique identifier of the provider that was configured for this step",
    )
    provider_name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the provider used for this step.")
    model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The name of the model used for this step.")
    context_window_limit: Mapped[Optional[int]] = mapped_column(
        nullable=True, doc="The context window limit configured for this step."
    )

    # Token-usage counters; non-nullable and default to 0 when not supplied.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(
        JSON, nullable=True, doc="Detailed breakdown of the completion tokens for this step."
    )
    tags: Mapped[Optional[List]] = mapped_column(JSON, nullable=True, doc="Metadata tags.")
    tid: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Transaction ID that processed the step.")

    # Relationships (foreign keys)
    organization: Mapped[Optional["Organization"]] = relationship("Organization")

    # Relationships (backrefs)
    # lazy="noload": accessing `messages` never queries the database; the
    # collection holds only what was attached in the current session.
    # cascade="save-update" only: messages appended here are added to the
    # session, but deleting a step does not cascade to its messages.
    messages: Mapped[List["Message"]] = relationship("Message", back_populates="step", cascade="save-update", lazy="noload")
